"""Weighted low rank approximation based on the SVD."""

import numpy as np

from . import svd


def weighted_lra(matrix, weight, rank, fisher_rank=1, eps=1e-8):
  """Approximate `matrix` as a Hadamard product of two low-rank factorizations.

  First the elementwise scale sqrt(1 / weight) is factored with `svd.svd` at
  rank `fisher_rank`. Then `matrix`, divided elementwise by the reconstruction
  of that scale, is factored with `svd.svd` a second time. The intended
  approximation of `matrix` is

      (scale_u @ scale_v) * (resid_u @ resid_v)

  Args:
    matrix: array to approximate. Presumably the same shape as `weight`,
      since it is divided elementwise by the reconstructed scale (TODO confirm).
    weight: per-entry weights (the original comments refer to them as
      "fisher"). Presumably non-negative: an entry <= -eps makes
      1 / weight negative, and `np.sqrt` then yields NaN.
    rank: rank budget for the second factorization (see the note below).
    fisher_rank: rank requested from `svd.svd` for the scale factorization.
    eps: any entry with |weight| < eps has `eps` added before taking the
      reciprocal, so exact zeros do not divide by zero.

  Returns:
    Tuple (scale_u, scale_v, resid_u, resid_v): the two factor pairs returned
    by the two `svd.svd` calls, in that order.
  """
  # Push near-zero weights away from zero so that 1 / weight stays finite.
  # Keep the arithmetic form (rather than np.where) so dtype promotion
  # matches the original implementation exactly.
  tiny_mask = np.abs(weight) < eps
  guarded = weight + eps * tiny_mask
  scale = np.sqrt(1 / guarded)

  # Low-rank factorization of the elementwise scale sqrt(1 / weight).
  scale_u, scale_v = svd.svd(scale, fisher_rank)

  # Rank budget for the second factorization. With a rank-1 scale, the
  # Hadamard product has rank at most `rank`, so the full budget is used.
  # NOTE(review): for fisher_rank > 1, rank(A * B) <= rank(A) * rank(B), so
  # the product may exceed `rank`. If fisher_rank >= rank, the budget below
  # is <= 0 and is passed to svd.svd unchecked. Confirm both are intended.
  if fisher_rank == 1:
    residual_rank = rank
  else:
    residual_rank = rank - fisher_rank

  # Divide out the low-rank scale, then factor what remains.
  rescaled = matrix / (scale_u @ scale_v)
  resid_u, resid_v = svd.svd(rescaled, residual_rank)

  return scale_u, scale_v, resid_u, resid_v
